# Rekall Memory Forensics
#
# Copyright (c) 2010 - 2012 Michael Ligh <michael.ligh@mnin.org>
# Copyright 2013 Google Inc. All Rights Reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#

from rekall import plugin
from rekall import obj
from rekall import testlib

from rekall.plugins.overlays.windows import pe_vtypes
from rekall.plugins.windows import common


class ImpScan(common.WinProcessFilter):
    """Scan for calls to imported functions."""

    __name = "impscan"

    # Maps a function name (as found in an export table) to the
    # "module!function" name that should be reported instead. Used by
    # _original_import() to undo import forwarding.
    FORWARDED_IMPORTS = {
        "RtlGetLastWin32Error" : "kernel32.dll!GetLastError",
        "RtlSetLastWin32Error" : "kernel32.dll!SetLastError",
        "RtlRestoreLastWin32Error" : "kernel32.dll!SetLastError",
        "RtlAllocateHeap" : "kernel32.dll!HeapAlloc",
        "RtlReAllocateHeap" : "kernel32.dll!HeapReAlloc",
        "RtlFreeHeap" : "kernel32.dll!HeapFree",
        "RtlEnterCriticalSection" : "kernel32.dll!EnterCriticalSection",
        "RtlLeaveCriticalSection" : "kernel32.dll!LeaveCriticalSection",
        "RtlDeleteCriticalSection" : "kernel32.dll!DeleteCriticalSection",
        "RtlZeroMemory" : "kernel32.dll!ZeroMemory",
        "RtlSizeHeap" : "kernel32.dll!HeapSize",
        "RtlUnwind" : "kernel32.dll!RtlUnwind",
        }

    @classmethod
    def args(cls, parser):
        """Declare the command line args we need."""
        super(ImpScan, cls).args(parser)
        parser.add_argument("-b", "--base", default=None, type="IntParser",
                            help="Base address in process memory if --pid is "
                            "supplied, otherwise an address in kernel space")

        parser.add_argument("-s", "--size", default=None, type="IntParser",
                            help="Size of memory to scan")

        parser.add_argument("-k", "--kernel", default=None, type="Boolean",
                            help="Scan in kernel space.")

    def __init__(self, base=None, size=None, kernel=None, **kwargs):
        """Scans the imports from a module.

        Often when dumping a PE executable from memory the import address tables
        are overwritten. This makes it hard to resolve function names when
        disassembling the binary.

        This plugin enumerates all dlls in the process address space and
        examines their export address tables. It then disassembles the
        executable code for calls to external functions. We attempt to resolve
        the names of the calls using the known exported functions we gathered in
        step 1.

        This technique can be used for a process, or the kernel itself. In the
        former case, we examine dlls, while in the latter case we examine kernel
        modules using the modules plugin.

        Args:

          base: Start disassembling at this address - this is normally the base
            address of the dll or module we care about. If omitted we use the
            kernel base (if in kernel mode) or the main executable (if in
            process mode).

          size: Disassemble this many bytes from the address space. If omitted
            we use the module which starts at base.

          kernel: The mode to use. If set, we operate in kernel mode.
        """
        super(ImpScan, self).__init__(**kwargs)
        self.base = base
        self.size = size
        self.kernel = kernel

    @staticmethod
    def _module_size(base_address, modules):
        """Return SizeOfImage of the module whose DllBase is base_address.

        Args:
          base_address: The base address to look for.
          modules: Iterable of module objects with DllBase/SizeOfImage members.

        Returns:
          The matching module's SizeOfImage, or None if no module starts at
          base_address.
        """
        for module in modules:
            if module.DllBase == base_address:
                return module.SizeOfImage

        return None

    def _enum_apis(self, all_mods):
        """Enumerate all exported functions from kernel or process space.

        To enum kernel APIs, all_mods is a list of drivers.
        To enum process APIs, all_mods is a list of DLLs.

        The function name is used if available, otherwise we take the ordinal
        value (or '' if neither is present).

        Args:
          all_mods: list of _LDR_DATA_TABLE_ENTRY.

        Returns:
          A dict mapping each exported function's address to a
          (module, func_pointer, function_name) tuple.
        """
        exports = {}

        for i, mod in enumerate(all_mods):
            self.session.report_progress("Scanning imports %s/%s" % (
                i, len(all_mods)))

            pe = pe_vtypes.PE(address_space=mod.obj_vm,
                              session=self.session, image_base=mod.DllBase)

            for _, func_pointer, func_name, ordinal in pe.ExportDirectory():
                function_name = func_name or ordinal or ''

                exports[func_pointer.v()] = (mod, func_pointer, function_name)

        return exports

    def _iat_scan(self, addr_space, calls_imported, apis, base_address,
                  end_address):
        """Scan forward from the lowest IAT entry found for new import entries.

        Every pointer in the scanned range that targets a known export (and
        does not point back into [base_address, end_address]) is added to
        calls_imported in place.

        Args:
          addr_space: an AS
          calls_imported: Import database - a dict keyed by IAT address.
            Modified in place.
          apis: dictionary of exported functions in the AS, keyed by address.
          base_address: memory base address for this module.
          end_address: end of valid address range.
        """
        if not calls_imported:
            return

        # Search the iat from the earliest function address to the latest
        # address for references to other functions.
        start_addr = min(calls_imported.keys())
        # NOTE(review): this is a byte distance, but it appears to be used
        # below as the Array element count, so the scan runs past the highest
        # known entry (bounded by the 2000 cap). Confirm that is intended.
        iat_size = min(max(calls_imported.keys()) - start_addr, 2000)

        # The IAT is a table of pointers to functions.
        iat = self.profile.Array(
            offset=start_addr, vm=addr_space, target="Pointer",
            count=iat_size, target_args=dict(target="Function"))

        for func_pointer in iat:
            func = func_pointer.dereference()

            if (not func or
                    (func.obj_offset > base_address and
                     func.obj_offset < end_address)): # skip call to self
                continue

            # Add the export to our database of imported calls.
            if (func.obj_offset in apis and
                    func_pointer.obj_offset not in calls_imported):
                iat_addr = func_pointer.obj_offset
                calls_imported[iat_addr] = (iat_addr, func)

    def _original_import(self, mod_name, func_name):
        """Revert a forwarded import to the original module and function name.

        Args:
          mod_name: current module name.
          func_name: current function name.

        Returns:
          A (module_name, function_name) tuple. If func_name is listed in
          FORWARDED_IMPORTS both parts come from that table, otherwise the
          inputs are returned unchanged.
        """
        if func_name in self.FORWARDED_IMPORTS:
            return tuple(self.FORWARDED_IMPORTS[func_name].split("!", 1))

        return mod_name, func_name

    CALL_RULE = {'mnemonic': 'CALL', 'operands': [
        {'type': 'MEM', 'target': "$target", 'address': '$address'}]}
    JMP_RULE = {'mnemonic': 'JMP', 'operands': [
        {'type': 'MEM', 'target': "$target", 'address': '$address'}]}

    def call_scan(self, addr_space, base_address, size_to_read):
        """Locate calls in a block of code.

        Disassemble a block of data and yield possible calls to imported
        functions.  We're looking for instructions such as these:

        x86:
        CALL DWORD [0x1000400]
        JMP  DWORD [0x1000400]

        x64:
        CALL QWORD [RIP+0x989d]

        On x86, the 0x1000400 address is an entry in the IAT or call table. It
        stores a DWORD which is the location of the API function being called.

        On x64, the 0x989d is a relative offset from the current instruction
        (RIP).

        So we simply disassemble the entire code section of the executable
        looking for calls, then we collect all the targets of the calls.

        Args:
          addr_space: an AS to scan with.
          base_address: memory base address to start disassembling at.
          size_to_read: number of bytes to disassemble; instructions starting
            at or beyond base_address + size_to_read are not examined.

        Yields:
          (instruction_address, memory_operand_address, target_function)
          tuples for each matching CALL/JMP with a non-zero target.
        """
        func_obj = self.profile.Function(vm=addr_space, offset=base_address)
        end_address = base_address + size_to_read

        for instruction in func_obj.disassemble(2**32):
            # end_address is exclusive: stop at the first instruction that
            # starts outside [base_address, end_address).
            if instruction.address >= end_address:
                break

            context = {}
            if (instruction.match_rule(self.CALL_RULE, context) or
                    instruction.match_rule(self.JMP_RULE, context)):
                target = context.get("$target")
                if target:
                    yield (instruction.address,
                           context.get("$address"),
                           self.profile.Function(vm=addr_space, offset=target))

    def find_process_imports(self, task):
        """Resolve the imported calls made by a module inside a process.

        The module scanned is the one at self.base if given, otherwise the
        process's first loaded module (the main executable). The scan size is
        self.size if given, otherwise the SizeOfImage of that module.

        Args:
          task: The process to examine (as yielded by filter_processes()).

        Yields:
          (iat, func_pointer, module, func_name) tuples, sorted by IAT
          address, for every call target that resolves to a known export.
          Unresolved targets are skipped.
        """
        task_space = task.get_process_address_space()
        all_mods = list(task.get_load_modules())

        # PEB is paged or no DLLs loaded - nothing to resolve against.
        if not all_mods:
            self.session.logging.error("Cannot load DLLs in process AS")
            return

        if self.base is None:
            # Its OK to blindly take the 0th element because the executable
            # is always the first module to load.
            base_address = int(all_mods[0].DllBase)
            size_to_read = self.size or all_mods[0].SizeOfImage
        else:
            base_address = int(self.base)
            size_to_read = self.size or self._module_size(
                base_address, all_mods)

        if not size_to_read:
            self.session.logging.error(
                "No module starts at %#x in process %s; please specify "
                "--size." % (base_address, task.UniqueProcessId))
            return

        size_to_read = int(size_to_read)
        end_address = base_address + size_to_read

        # Exported function of all other modules in the address space.
        apis = self._enum_apis(all_mods)

        calls_imported = {}
        for address, iat, destination in self.call_scan(
                task_space, base_address, size_to_read):
            self.session.report_progress("Resolving import %s->%s" % (
                address, iat))
            calls_imported[iat] = (address, destination)

        # Scan the IAT for additional functions.
        self._iat_scan(task_space, calls_imported, apis,
                       base_address, end_address)

        for iat, (_, destination) in sorted(calls_imported.items()):
            resolved = apis.get(destination.obj_offset)
            if resolved:
                module, func_pointer, func_name = resolved
                yield iat, func_pointer, module, func_name

    def find_kernel_import(self):
        """Resolve the imported calls made by a kernel module.

        The module scanned is the one at self.base if given, otherwise the
        kernel image ("kernel_base" session parameter). The scan size is
        self.size if given, otherwise the SizeOfImage of the kernel module
        starting at that base.

        Yields:
          (iat, func_pointer, module, func_name) tuples sorted by IAT
          address. Targets that do not resolve to a known export are still
          yielded, with NoneObject("Unknown") placeholders.

        Raises:
          plugin.PluginError: if no size was given and no kernel module
            starts at the base address.
        """
        # If the user has not specified the base, we just use the kernel's
        # image.
        base_address = self.base
        if base_address is None:
            base_address = self.session.GetParameter("kernel_base")

        # The kernel module list supplies both the export database and, when
        # --size is not given, the size of the module at base_address.
        all_mods = list(self.session.plugins.modules().lsmod())

        size_to_read = self.size or self._module_size(base_address, all_mods)
        if not size_to_read:
            raise plugin.PluginError("You must specify a size to read.")

        size_to_read = int(size_to_read)
        end_address = base_address + size_to_read

        apis = self._enum_apis(all_mods)

        calls_imported = {}
        for address, iat, destination in self.call_scan(
                self.kernel_address_space, base_address, size_to_read):
            calls_imported[iat] = (address, destination)
            self.session.report_progress(
                "Found %s imports" % len(calls_imported))

        # Scan the IAT for additional functions.
        self._iat_scan(self.kernel_address_space, calls_imported, apis,
                       base_address, end_address)

        for iat, (_, destination) in sorted(calls_imported.items()):
            # apis is keyed by function address, i.e. the target's offset.
            module, func_pointer, func_name = apis.get(
                destination.obj_offset, (
                    obj.NoneObject("Unknown"),
                    obj.NoneObject("Unknown"),
                    obj.NoneObject("Unknown")))

            yield iat, func_pointer, module, func_name

    def render(self, renderer):
        """Render one imports table (kernel mode) or one per process."""
        table_header = [("IAT", 'iat', "[addrpad]"),
                        ("Call", 'call', "[addrpad]"),
                        ("Module", 'module', "20"),
                        ("Function", 'function', ""),
                       ]

        if self.kernel:
            renderer.format("Kernel Imports\n")

            renderer.table_header(table_header)
            for iat, func, mod, func_name in self.find_kernel_import():
                mod_name, func_name = self._original_import(
                    mod.BaseDllName, func_name)

                renderer.table_row(iat, func, mod_name, func_name)
        else:
            for task in self.filter_processes():
                renderer.section()
                renderer.format("Process {0} PID {1}\n", task.ImageFileName,
                                task.UniqueProcessId)
                renderer.table_header(table_header)

                for iat, func, mod, func_name in self.find_process_imports(
                        task):
                    mod_name, func_name = self._original_import(
                        mod.BaseDllName, func_name)
                    renderer.table_row(iat, func, mod_name, func_name)


class TestImpScan(testlib.SimpleTestCase):
    """Regression test that exercises the impscan plugin's command line."""

    PARAMETERS = {"commandline": "impscan %(pids)s"}
